{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comparing randomized search and grid search for hyperparameter estimation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adapted from http://scikit-learn.org/stable/auto_examples/model_selection/randomized_search.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare randomized search and grid search for optimizing hyperparameters of a random forest. All parameters that influence the learning are searched simultaneously (except for the number of estimators, which poses a time / quality tradeoff).\n",
    "\n",
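    "The randomized search and the grid search tune the same set of hyperparameters, although the randomized search samples the integer-valued ones from ranges up to 10 while the grid uses a few fixed values. The best parameter settings they find are quite similar, while the randomized search runs much faster: 8.45 seconds for 20 candidates versus 59.64 seconds for 216 grid candidates in the run recorded below.\n",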
    "\n",
    "The randomized search scores slightly worse (best mean validation score 0.927 versus 0.933), though this is most likely a noise effect and would probably not carry over to a held-out test set.\n",
    "\n",
    "Note that in practice, one would not search over this many different parameters simultaneously using grid search, but pick only the ones deemed most important."
   ]
  },
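  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Much of the run-time gap comes from how many candidates each search fits. The sketch below is only an illustration, not part of the original example: it copies the grid values used in the search cell of this notebook and counts the combinations a full grid has to evaluate.\n",
    "\n",
    "```julia\n",
    "# Values copied from the grid used in the search cell of this notebook.\n",
    "param_grid = Dict(\"max_depth\" => [3, nothing],\n",
    "                  \"max_features\" => [1, 3, 10],\n",
    "                  \"min_samples_split\" => [2, 3, 10],\n",
    "                  \"min_samples_leaf\" => [1, 3, 10],\n",
    "                  \"bootstrap\" => [true, false],\n",
    "                  \"criterion\" => [\"gini\", \"entropy\"])\n",
    "\n",
    "# A full grid evaluates every combination, so the candidate count is the\n",
    "# product of the number of values listed for each parameter.\n",
    "n_grid = prod(length(v) for v in values(param_grid))\n",
    "println(\"grid candidates: \", n_grid)  # 2 * 3 * 3 * 3 * 2 * 2 = 216\n",
    "\n",
    "# Adding one more parameter with k values multiplies the count by k.\n",
    "println(\"with one extra 3-valued parameter: \", 3 * n_grid)  # 648\n",
    "```\n",
    "\n",
    "The randomized search, by contrast, fits exactly `n_iter` candidates no matter how many parameters are searched, which is why it stays cheap as the search space grows."
   ]
  },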
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "using ScikitLearn, Random, Printf, Statistics\n",
    "using PyCall: @pyimport\n",
    "\n",
    "@pyimport scipy.stats as stats\n",
    "using ScikitLearn.GridSearch: GridSearchCV, RandomizedSearchCV\n",
    "\n",
    "@sk_import datasets: load_digits\n",
    "@sk_import ensemble: RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RandomizedSearchCV took 8.45 seconds for 20 candidates, parameter settings.\n",
      "Model with rank:1\n",
      "Mean validation score: 0.927 (std: 0.012)\n",
      "Parameters: Dict{Symbol,Any}(Pair{Symbol,Any}(:max_features,6),Pair{Symbol,Any}(:bootstrap,false),Pair{Symbol,Any}(:min_samples_split,7),Pair{Symbol,Any}(:max_depth,nothing),Pair{Symbol,Any}(:criterion,\"gini\"),Pair{Symbol,Any}(:min_samples_leaf,4))\n",
      "\n",
      "Model with rank:2\n",
      "Mean validation score: 0.922 (std: 0.007)\n",
      "Parameters: Dict{Symbol,Any}(Pair{Symbol,Any}(:max_features,4),Pair{Symbol,Any}(:bootstrap,false),Pair{Symbol,Any}(:min_samples_split,8),Pair{Symbol,Any}(:max_depth,nothing),Pair{Symbol,Any}(:criterion,\"entropy\"),Pair{Symbol,Any}(:min_samples_leaf,6))\n",
      "\n",
      "Model with rank:3\n",
      "Mean validation score: 0.916 (std: 0.023)\n",
      "Parameters: Dict{Symbol,Any}(Pair{Symbol,Any}(:max_features,10),Pair{Symbol,Any}(:bootstrap,false),Pair{Symbol,Any}(:min_samples_split,8),Pair{Symbol,Any}(:max_depth,nothing),Pair{Symbol,Any}(:criterion,\"gini\"),Pair{Symbol,Any}(:min_samples_leaf,8))\n",
      "\n",
      "GridSearchCV took 59.64 seconds for 216 candidate parameter settings.\n",
      "Model with rank:1\n",
      "Mean validation score: 0.933 (std: 0.009)\n",
      "Parameters: Dict{Symbol,Any}(Pair{Symbol,Any}(:max_features,10),Pair{Symbol,Any}(:bootstrap,false),Pair{Symbol,Any}(:min_samples_split,3),Pair{Symbol,Any}(:max_depth,nothing),Pair{Symbol,Any}(:min_samples_leaf,1),Pair{Symbol,Any}(:criterion,\"entropy\"))\n",
      "\n",
      "Model with rank:2\n",
      "Mean validation score: 0.932 (std: 0.014)\n",
      "Parameters: Dict{Symbol,Any}(Pair{Symbol,Any}(:max_features,10),Pair{Symbol,Any}(:bootstrap,false),Pair{Symbol,Any}(:min_samples_split,2),Pair{Symbol,Any}(:max_depth,nothing),Pair{Symbol,Any}(:min_samples_leaf,3),Pair{Symbol,Any}(:criterion,\"entropy\"))\n",
      "\n",
      "Model with rank:3\n",
      "Mean validation score: 0.932 (std: 0.014)\n",
      "Parameters: Dict{Symbol,Any}(Pair{Symbol,Any}(:max_features,10),Pair{Symbol,Any}(:bootstrap,false),Pair{Symbol,Any}(:min_samples_split,3),Pair{Symbol,Any}(:max_depth,nothing),Pair{Symbol,Any}(:min_samples_leaf,3),Pair{Symbol,Any}(:criterion,\"entropy\"))\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# get some data\n",
    "digits = load_digits()\n",
    "X, y = digits[\"data\"], digits[\"target\"]\n",
    "\n",
    "# build a classifier\n",
    "clf = RandomForestClassifier(n_estimators=20, random_state=2)\n",
    "\n",
    "# Utility function to report best scores\n",
    "function report(grid_scores, n_top=3)\n",
    "    top_scores = sort(grid_scores, by=x->x.mean_validation_score, rev=true)[1:n_top]\n",
    "    for (i, score) in enumerate(top_scores)\n",
    "        println(\"Model with rank:$i\")\n",
    "        @printf(\"Mean validation score: %.3f (std: %.3f)\\n\",\n",
    "                score.mean_validation_score,\n",
    "                std(score.cv_validation_scores))\n",
    "        println(\"Parameters: $(score.parameters)\")\n",
    "        println(\"\")\n",
    "    end\n",
    "end\n",
    "\n",
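    "# scipy.stats.randint(a, b) draws integers uniformly from the half-open range [a, b)\n",
    "sampler(a, b) = stats.randint(a, b)\n",
    "# TODO: switch to a pure-Julia sampler, e.g. DiscreteUniform(a, b-1) from Distributions.jl\n",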
    "\n",
    "# specify parameters and distributions to sample from\n",
    "param_dist = Dict(\"max_depth\"=> [3, nothing],\n",
    "                  \"max_features\"=> sampler(1, 11),\n",
    "                  \"min_samples_split\"=> sampler(2, 11),\n",
    "                  \"min_samples_leaf\"=> sampler(1, 11),\n",
    "                  \"bootstrap\"=> [true, false],\n",
    "                  \"criterion\"=> [\"gini\", \"entropy\"])\n",
    "\n",
    "# run randomized search\n",
    "n_iter_search = 20\n",
    "random_search = RandomizedSearchCV(clf, param_dist,\n",
    "                                   n_iter=n_iter_search, random_state=MersenneTwister(42))\n",
    "\n",
    "start = time()\n",
    "fit!(random_search, X, y)\n",
    "@printf(\"RandomizedSearchCV took %.2f seconds for %d candidate parameter settings.\\n\", (time() - start), n_iter_search)\n",
    "\n",
    "report(random_search.grid_scores_)\n",
    "\n",
    "# use a full grid over all parameters\n",
    "param_grid = Dict(\"max_depth\"=> [3, nothing],\n",
    "                  \"max_features\"=> [1, 3, 10],\n",
    "                  \"min_samples_split\"=> [2, 3, 10],  # min_samples_split=1 became an error in scikit-learn 0.18.1\n",
    "                  \"min_samples_leaf\"=> [1, 3, 10],\n",
    "                  \"bootstrap\"=>[true, false],\n",
    "                  \"criterion\"=> [\"gini\", \"entropy\"])\n",
    "\n",
    "# run grid search\n",
    "grid_search = GridSearchCV(clf, param_grid)\n",
    "\n",
    "start = time()\n",
    "fit!(grid_search, X, y)\n",
    "\n",
    "@printf(\"GridSearchCV took %.2f seconds for %d candidate parameter settings.\\n\",\n",
    "        time() - start, length(grid_search.grid_scores_))\n",
    "\n",
    "report(grid_search.grid_scores_)"
   ]
  },
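  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the cell above, `sampler(a, b)` wraps `scipy.stats.randint`, which draws integers from the half-open range `[a, b)`, so `sampler(1, 11)` yields values from 1 to 10. The sketch below mimics that behaviour in plain Julia. The helper `draw` is a hypothetical name introduced here for illustration; it is not how `RandomizedSearchCV` samples internally.\n",
    "\n",
    "```julia\n",
    "using Random  # standard library on Julia 0.7 and later\n",
    "\n",
    "# Hypothetical stand-in for sampler(a, b): draw one integer from [a, b),\n",
    "# i.e. from a up to and including b - 1.\n",
    "draw(rng, a, b) = rand(rng, a:(b - 1))\n",
    "\n",
    "rng = MersenneTwister(42)\n",
    "samples = [draw(rng, 1, 11) for _ in 1:1000]\n",
    "\n",
    "# Every sample lies between 1 and 10; the upper bound 11 is never drawn.\n",
    "println(all(s -> 1 <= s <= 10, samples))\n",
    "println(extrema(samples))\n",
    "```\n",
    "\n",
    "This is why the distributions above use 11 as the upper bound while the grid search lists 10 as its largest value."
   ]
  },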
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
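    "If `refit=true`, `GridSearchCV` and `RandomizedSearchCV` refit the best parameter setting they found and then behave like that estimator, so `predict`, `predict_proba`, and `score` can be called on the search object directly. Note that the score below is computed on the same data that was used for the search, so it reflects training accuracy rather than generalization."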
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(predict(random_search,X))[1:10] = [0,1,2,3,4,5,6,7,8,9]\n",
      "(predict_proba(random_search,X))[1:10] = [1.0,0.0,0.02,0.0,0.144167,0.0,0.0125,0.0,0.0,0.0]\n",
      "score(random_search,X,y) = 0.9994435169727324\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9994435169727324"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@show predict(random_search, X)[1:10]\n",
    "@show predict_proba(random_search, X)[1:10]\n",
    "@show score(random_search, X, y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.5.1",
   "language": "julia",
   "name": "julia-0.5"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
